#!/usr/bin/env python
# coding=utf-8

from __future__ import print_function, division
import os
from os import listdir

import torch
from torch.utils.data import Dataset

from main import n_letters, all_letters

class TextDataset(Dataset):
    '''Dataset yielding the text of each file in a directory.

    Every item is a pair ``(input_, target)`` derived from the same
    whitespace-stripped file contents: ``input_`` is that text passed through
    ``transform`` and ``target`` is that text passed through
    ``target_transform``.
    '''

    def __init__(self, root_dir, transform=None, target_transform=None):
        '''
        Args:
            root_dir (string): Directory that holds all cve text files.
            transform (callable, optional): Applied to the file text to build
                the input. ``None`` returns the text unchanged.
            target_transform (callable, optional): Applied to the file text to
                build the target. ``None`` returns the text unchanged.
        '''
        self.root_dir = root_dir
        # listdir() yields bare names in arbitrary order and includes
        # directories. Keep only regular files (a directory cannot be read as
        # text) and sort them so a given index always maps to the same file.
        self.cve_file_list = sorted(
            name for name in listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        '''Return ``(input_, target)`` built from the file at ``index``.'''
        # Names from listdir() are relative to root_dir, not to the CWD.
        path = os.path.join(self.root_dir, self.cve_file_list[index])
        with open(path) as f:
            text = f.read().strip()

        input_ = text if self.transform is None else self.transform(text)
        target = (text if self.target_transform is None
                  else self.target_transform(text))
        return input_, target

    def __len__(self):
        '''Return the number of text files in the dataset.'''
        return len(self.cve_file_list)



